import torch
import lietorch
import numpy as np
from factor_graph import FactorGraph
from geom.ba import InitializeGravityDirectionDynamic, InitializeVeloBiasGdir, InitializeFullInertialBA, BA_prepare
from collections import defaultdict

class TrackFrontend:
    """Keyframe-tracking frontend that drives a local ``FactorGraph`` window.

    Once ``video.counter.value`` reaches ``warmup`` keyframes, a vision-only
    initialization is run. Afterwards, each call adds the newest keyframe to the
    graph and optimizes it: vision-only while ``self.t1 <= imu_late_init_from``,
    visual-inertial afterwards. The second-newest keyframe may then be dropped
    again if its distance to its predecessor is below ``keyframe_thresh``.
    When ``self.t1`` equals ``imu_late_init_from`` (and no keyframe was dropped),
    a one-off visual-inertial initialization is performed. It estimates gravity
    direction, velocities, biases and scale.
    """

    def __init__(self, net, video, config, args):
        """Store shared state and read the frontend hyper-parameters.

        Args:
            net: network whose ``update`` operator drives the factor graph.
            video: shared keyframe buffer (poses, disparities, IMU states, ...).
            config: mapping providing ``frontend_nms``, ``keyframe_thresh``,
                ``frontend_window``, ``frontend_thresh``, ``frontend_radius``,
                ``mono_depth_alpha``, ``imu_late_init_from`` and optionally
                ``first_pure_tracking`` (defaults to False).
            args: namespace providing ``Tcb``, ``init_bg``, ``init_ba`` and ``init_g``.
        """
        self.video = video
        self.update_op = net.update
        self.graph = FactorGraph(video, net.update, max_factors=48)
        self.args = args
        # local optimization window: exclusive end index of the tracked keyframes
        self.t1 = 0

        # number of times a keyframe was deleted at each index
        self.delete_count = defaultdict(int)
        # when True, an index where a keyframe was already deleted is never deleted again
        self.keep_kf_once_deleted = True

        # frontend variables
        self.max_age = 25
        self.iters1 = 4
        self.iters2 = 2
        self.warmup = 10

        self.first_pure_tracking = config.get("first_pure_tracking", False)
        self.frontend_nms = config["frontend_nms"]
        self.keyframe_thresh = config["keyframe_thresh"]
        self.frontend_window = config["frontend_window"]
        self.frontend_thresh = config["frontend_thresh"]
        self.frontend_radius = config["frontend_radius"]
        self.video.mono_depth_alpha = config["mono_depth_alpha"]

        self.Tcb = args.Tcb
        self.init_bg = args.init_bg
        self.init_ba = args.init_ba
        self.init_g = args.init_g

        self.scale = 1
        self.imu_init_fix_scale = False  # should be False since IMU can provide metric scale
        self.imu_late_init_from = config["imu_late_init_from"]

    def __graph_update(self, iters, t0=None, t1=None, use_inactive=False, use_mono=False, mode='inertial', disable_vision=False):
        """Run ``iters`` factor-graph updates, or the one-off inertial initialization.

        Any ``mode`` that does not contain ``"initial_inertial"`` calls
        ``FactorGraph.update`` ``iters`` times. The ``inertial`` / ``tracking``
        switches are enabled when those substrings appear in ``mode``.
        ``mode == "initial_inertial"`` runs the visual-inertial initialization.
        In that mode ``t0``, ``t1``, ``use_inactive``, ``use_mono`` and
        ``disable_vision`` are ignored.

        Raises:
            ValueError: if ``mode`` contains ``"initial_inertial"`` without being
                exactly that string, or if ``iters < 1`` in initialization mode.
        """
        if "initial_inertial" not in mode:
            for _ in range(iters):
                self.graph.update(t0, t1, use_inactive=use_inactive, use_mono=use_mono, inertial=('inertial' in mode), tracking=('tracking' in mode), disable_vision=disable_vision)
        elif mode == "initial_inertial":
            self.__initialize_inertial(iters)
        else:
            raise ValueError(f"Invalid mode: {mode}")

    def __initialize_inertial(self, iters):
        """Visual-inertial initialization over keyframes ``[0, self.t1)``.

        The initialization runs in three steps:

        1. Estimate the gravity direction ``Rwg``. This step is skipped when
           ``video.Rwg`` is already set.
        2. Estimate velocities, biases, gravity direction and scale.
        3. Run a full inertial bundle adjustment on a fresh factor graph.

        The results are written back into ``self.video``. The video is then
        rescaled (unless ``imu_init_fix_scale``) and ``video.IMU_initialized``
        is set.

        Raises:
            ValueError: if ``iters < 1``. The bundle-adjustment loop must run at
                least once to produce the indices/mask used for upsampling, so
                this is checked before any state is modified.
        """
        if iters < 1:
            raise ValueError(f"initial_inertial requires iters >= 1, got {iters}")
        print(self.t1)
        t0 = 0
        poses_cw = lietorch.SE3(self.video.poses[:self.t1][None])
        poses_bw = self.Tcb.inv() * poses_cw
        velos_w = self.video.velos_w[:self.t1].unsqueeze(0)
        biass_w = self.video.biass_w[:self.t1].unsqueeze(0)
        Rwg = self.video.Rwg
        fix_front = 0   # fix the velocity of the very first frame as zero vector
        mean_g = None

        # step 1: gravity direction (only if no estimate exists yet)
        if Rwg is None:
            Rwg = np.eye(3)
            print("="*100, "\nInitial gravity direction from frame '{}' to '{}'".format(t0, self.t1))
            for _ in range(iters):
                Rwg = InitializeGravityDirectionDynamic(t0, self.t1, poses_bw, velos_w, biass_w, self.video.preints, Rwg)
            print('- - gravity vector at Cam frame', Rwg @ self.init_g)
            print('- - R from g to camera', Rwg)

        # step 2: velocity, bias, gravity direction and scale
        print("="*100, "\nInitial velocity, bias, gdir and scale from frame '{}' to '{}'".format(t0, self.t1))
        scale = 1
        init_bias_scale = 1 # should just be 1, unless debugging
        for _ in range(iters):
            velos_w, biass_w, Rwg, scale = InitializeVeloBiasGdir(t0, self.t1, poses_bw, velos_w, biass_w, self.video.preints, Rwg, self.init_g, self.Tcb, mean_g, scale, self.imu_init_fix_scale, fix_front=fix_front, bias_scale=init_bias_scale)
        self.video.velos_w[:self.t1] = velos_w[0]
        self.video.biass_w[:self.t1] = biass_w[0]
        self.video.reintegrate_all()
        print('- - mean bias', list(torch.mean(biass_w, dim=1).cpu().numpy()))
        print('- - R from g to world', Rwg)
        print('- - Bias at Imu frame', biass_w[0, 0, 3:])
        print('- - scale', scale)
        for i in range(0, self.t1):
            print('- - - velo', i, list(self.video.velos_w[i].cpu().numpy()))
            print('- - - bias', i, list(self.video.biass_w[i].cpu().numpy()))

        # step 3: full inertial bundle adjustment on a dedicated factor graph
        disps = self.video.disps[:self.t1][None]
        intrs = self.video.intrinsics[:self.t1][None]
        print("="*100, "\nInitial full inertial BA from frame '{}' to '{}'".format(t0, self.t1))

        newgraph = FactorGraph(self.video, self.graph.update_op, corr_impl="alt", max_factors=1000)
        newgraph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)

        for _ in range(iters):
            # one network update, followed by two BA iterations on its targets/weights
            t0, target, weight, eta, ii, jj, opt_ii, upmask = newgraph.get_network_update_full_graph(t0=t0)

            for _ in range(2):
                H, E, C, v, w, chi2, chi2R = BA_prepare(target, weight, eta, poses_bw, disps, intrs[:,:,:], ii, jj, self.video.Tcb, fixedp=0, D=15, t0=t0, t1=self.t1)
                print("- - Chi2 error re-proj raw/robust: {:.4f} {:.4f}".format((chi2).item(), (chi2R).item()))
                print(H.shape, E.shape, C.shape, v.shape, w.shape)
                # NOTE(review): fix_front=15 here vs fix_front=0 in step 2 -- meaning defined in geom.ba; confirm
                poses_bw, velos_w, biass_w, disps, _, _, scale = InitializeFullInertialBA(t0, self.t1, poses_bw, velos_w, biass_w, disps, None, None,
                                                                                            ii, self.video.preints, Rwg, scale, H, E, C, v, w, imu_init_fix_scale=self.imu_init_fix_scale, fix_front=15, bias_scale=init_bias_scale)
                poses_cw = self.Tcb * poses_bw
            # must update each iter
            self.video.poses[:self.t1] = poses_cw.data[0]
            self.video.disps[:self.t1] = disps[0]

        # upsample the disp, using the upmask of the last network update
        self.video.upsample(torch.unique(opt_ii), upmask)
        self.video.velos_w[:self.t1] = velos_w[0]
        self.video.biass_w[:self.t1] = biass_w[0]
        self.video.reintegrate_all()
        self.video.Rwg = Rwg
        print('- - scale', scale)
        if not self.imu_init_fix_scale:
            self.scale = scale
            self.video.rescale(scale, self.t1)

        self.video.dirty[:self.t1] = True

        # debug output
        for i in range(0, self.t1):
            print('- - - velo', i, list(self.video.velos_w[i].cpu().numpy()))
            print('- - - bias', i, list(self.video.biass_w[i].cpu().numpy()))
        print('- - mean bias', list(torch.mean(self.video.biass_w[:self.t1], dim=0).cpu().numpy()))
        print("- - gravity direction", Rwg)
        print("- - initialization finished!!!!")
        self.video.IMU_initialized = True

    def __update(self, is_last):
        """Add the newest keyframe, optimize, and return keyframe indices for mapping.

        Args:
            is_last: whether the newest keyframe is the last one of the sequence.
                If so, it is included in the returned indices.

        Returns:
            ``[]`` when the second-newest keyframe was removed. Otherwise, a CUDA
            tensor of keyframe indices to use for mapping. Right after the
            inertial initialization, this is ``arange(0, t1 - 1)``.
        """
        self.t1 += 1

        if self.graph.corr is not None:
            # retire factors older than max_age (store=True presumably keeps them as inactive factors)
            self.graph.rm_factors(self.graph.age > self.max_age, store=True)

        self.graph.add_proximity_factors(self.t1-5, max(self.t1-self.frontend_window, 0),
            rad=self.frontend_radius, nms=self.frontend_nms, thresh=self.frontend_thresh, remove=True)

        # ratio between the medians of the current and the prior disparity of the newest keyframe
        self.video.dscales[self.t1-1] = self.video.disps[self.t1-1].median() / self.video.disps_prior[self.t1-1].median()

        # tracking here is for speed (reduce computational cost)
        suffix = '_tracking' if self.first_pure_tracking else ''
        for itr in range(self.iters1):
            use_mono = itr > 1
            if self.t1 > self.imu_late_init_from:
                self.__graph_update(1, None, None, use_inactive=True, use_mono=use_mono, mode='inertial'+suffix, disable_vision=False)
            else:
                # vision_only_tracking just to save computational cost, also fine to use vision_only
                self.__graph_update(1, None, None, use_inactive=True, use_mono=use_mono, mode='vision_only_tracking')
                self.__graph_update(1, None, None, use_inactive=True, use_mono=use_mono, mode='vision_only'+suffix)

        # removal candidate: keyframe t1-2, judged by its distance to keyframe t1-3
        d = self.video.distance([self.t1-3], [self.t1-2], bidirectional=True)

        already_skip = (self.keep_kf_once_deleted and self.delete_count[self.t1-2] > 0)
        # timestamp gap spanning the candidate; presumably counter.value == self.t1 here (TODO confirm)
        dT = (self.video.kf_stamps[self.video.counter.value-1] - self.video.kf_stamps[self.video.counter.value-3])
        criteria = d.item() < self.keyframe_thresh and not already_skip and (dT < 3)

        # NOTE: replacing the condition below with `False and criteria` disables keyframe deletion
        imu_just_initialized = False
        if criteria:
            self.graph.rm_keyframe(self.t1 - 2)
            self.delete_count[self.t1-2] += 1
            with self.video.get_lock():
                self.video.counter.value -= 1
                self.t1 -= 1
            update_idx = []
        else:
            refine_mode = "inertial" if self.t1 > self.imu_late_init_from else "vision_only"
            for _ in range(self.iters2):
                self.__graph_update(1, None, None, use_inactive=True, mode=refine_mode)

            if self.t1 == self.imu_late_init_from:
                self.__graph_update(5, None, None, use_inactive=True, mode="vision_only")
                print("begin initial_inertial")
                self.__graph_update(8, t0=1, t1=None, use_inactive=True, mode="initial_inertial")
                imu_just_initialized = True

            # not to use the newest keyframe directly for mapping, since it is not stable.
            # But if it is already the last keyframe, then use it
            idx_end = self.t1 if is_last else self.t1 - 1
            update_idx = torch.arange(self.graph.ii.min(), idx_end, device='cuda')

        # set pose for next iteration
        self.video.poses[self.t1] = self.video.poses[self.t1-1]
        self.video.disps[self.t1] = self.video.disps[self.t1-1].mean()
        self.video.velos_w[self.t1] = self.video.velos_w[self.t1-1]
        self.video.biass_w[self.t1] = self.video.biass_w[self.t1-1]

        # update visualization
        self.video.dirty[self.graph.ii.min():self.t1] = True

        # The inertial initialization re-optimized every keyframe, so remap from scratch.
        # Keyed on the flag rather than on t1: after a keyframe deletion, t1 can fall
        # back to imu_late_init_from without any initialization having run.
        if imu_just_initialized:
            update_idx = torch.arange(0, self.t1-1, device='cuda')
            self.video.dirty[0:self.t1] = True
            if hasattr(self.video, 'gs') and self.video.gs is not None:
                self.video.gs.remove_all_gaussians()
        return update_idx

    def __initialize(self):
        """Vision-only bootstrap of the SLAM system on the first ``warmup`` keyframes.

        Returns:
            CUDA tensor ``arange(t1 - 1)`` of keyframe indices to use for mapping.
        """
        self.t1 = self.video.counter.value

        # initial optimization
        self.graph.add_neighborhood_factors(0, self.t1, r=3)
        for _ in range(8):
            self.__graph_update(1, t0=1, use_inactive=True, use_mono=False, mode="vision_only")

        # refine optimization
        self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)
        for i in range(self.t1):
            self.video.dscales[i] = self.video.disps[i].median() / self.video.disps_prior[i].median()
        for itr in range(8):
            self.__graph_update(1, t0=1, use_inactive=True, use_mono=itr>2, mode="vision_only")

        self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)
        for itr in range(8):
            self.__graph_update(1, t0=1, use_inactive=True, use_mono=itr>2, mode="vision_only")

        # useful for e.g. fastlivo
        self.video.normalize()

        # initialization complete; seed the next slot from the newest keyframe
        self.video.is_initialized = True
        self.video.poses[self.t1] = self.video.poses[self.t1-1].clone()
        self.video.disps[self.t1] = self.video.disps[self.t1-4:self.t1].mean()
        self.video.velos_w[self.t1] = self.video.velos_w[self.t1-1]
        self.video.biass_w[self.t1] = self.video.biass_w[self.t1-1]

        with self.video.get_lock():
            self.video.dirty[:self.t1] = True
        self.graph.rm_factors(self.graph.ii < self.t1-4, store=True)
        return torch.arange(self.t1-1, device='cuda')

    def __call__(self, is_last):
        """Main entry point: initialize once ``warmup`` keyframes exist, then track.

        Args:
            is_last: whether the newest keyframe is the last one of the sequence.

        Returns:
            Keyframe indices to use for mapping (``[]`` when nothing was done).
        """
        self.to_update = []

        # do initialization
        if not self.video.is_initialized and self.video.counter.value == self.warmup:
            self.to_update = self.__initialize()

        # do update
        elif self.video.is_initialized and self.t1 < self.video.counter.value:
            self.to_update = self.__update(is_last)

        return self.to_update
